#include <torch/extension.h>

// Forward declaration of the LAMB optimizer update routine that is exposed to
// Python as "lamb" by the PYBIND11_MODULE block in this file. Only the
// signature is visible here; the definition lives in another translation unit
// (NOTE(review): presumably a CUDA source, judging by the `_cuda` suffix —
// confirm against the build sources).
//
// Parameter notes — inferred from names only, verify against the definition:
//   chunk_size, noop_flag, tensor_lists
//       presumably the chunking size, a skip/overflow flag tensor, and the
//       grouped tensors (e.g. grads/params/optimizer state) to update — TODO
//       confirm the expected number and order of the inner lists.
//   lr, beta1, beta2, epsilon, step, bias_correction, weight_decay,
//   grad_averaging, mode
//       scalar optimizer hyperparameters/switches; the meaning of the integer
//       flags (`bias_correction`, `grad_averaging`, `mode`) is not shown here.
//   global_grad_norm, max_grad_norm
//       presumably used for gradient clipping — confirm.
//
// Returns nothing; any results are presumably written into the passed tensors.
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,
                            const int bias_correction, const float weight_decay, const int grad_averaging,
                            const int mode, const float global_grad_norm, const float max_grad_norm);

// Python module definition: registers the LAMB update routine under the
// Python-visible name "lamb".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Call policy: the GIL is released for the duration of the bound call and
  // re-acquired when it returns, so other Python threads can run meanwhile.
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  // Docstring attached to the Python function (kept verbatim).
  const char* const lamb_doc = "Computes and apply update for LAMB optimizer";

  m.def("lamb", &multi_tensor_lamb_cuda, lamb_doc, ReleaseGil());
}
